Add CoupledTrainStepper
#809
  ) -> StepLossABC:
      if self.n_steps == 0 or self.weight == 0.0:
-         return NullLossContributions()
+         return NullLossContributions(loss_obj)
This preserves the existing behavior where we used a component stepper's effective_loss_scaling to compute mse_fractional_components metrics even if the stepper had no loss contribution in coupled training.
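A minimal sketch of the idea in this comment, assuming a null contribution that returns zero loss but still forwards the wrapped object's effective_loss_scaling. FakeStepLoss, the "sst" key, and the method bodies are hypothetical stand-ins, not the fme implementation.

```python
class FakeStepLoss:
    """Hypothetical stand-in for a step loss exposing per-variable scaling."""

    def __init__(self, scaling):
        self._scaling = scaling

    @property
    def effective_loss_scaling(self):
        return self._scaling

    def __call__(self, prediction, target):
        # Plain sum of squared errors, only so the object is callable.
        return sum((p - t) ** 2 for p, t in zip(prediction, target))


class NullLossContributions:
    """Assumed shape: contributes no loss, but keeps loss_obj for metrics."""

    def __init__(self, loss_obj):
        self._loss = loss_obj

    @property
    def effective_loss_scaling(self):
        # Forward the scaling so fractional-component metrics can still use it.
        return self._loss.effective_loss_scaling

    def __call__(self, prediction, target):
        return 0.0


null = NullLossContributions(FakeStepLoss({"sst": 2.0}))
```

Under these assumptions, the component adds nothing to the training loss while its scaling remains available to metric code.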
mcgibbon
left a comment
Just some nits (nits are optional), I don't need to re-review them. LGTM
fme/coupled/test_loss.py
Outdated
  @property
  def effective_loss_scaling(self):
      raise NotImplementedError
-     raise NotImplementedError
+     raise NotImplementedError()
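For context, both spellings raise the same exception in Python, since raising a class instantiates it with no arguments; the suggestion is a style preference. A small sketch (the function names are hypothetical):

```python
def raises_class():
    raise NotImplementedError  # Python instantiates the class implicitly


def raises_instance():
    raise NotImplementedError()  # explicit instance


def caught(fn):
    """Return the type and args of the NotImplementedError raised by fn."""
    try:
        fn()
    except NotImplementedError as err:
        return type(err), err.args
```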
fme/coupled/test_loss.py
Outdated
  atmos_loss_config = LossContributionsConfig()
  atmosphere_loss = atmos_loss_config.build(
-     loss_obj=lambda *_, **__: torch.tensor(5.25),
+     loss_obj=Mock(spec=StepLoss, side_effect=lambda *_, **__: torch.tensor(5.25)),
nit: The three lines changed use three different ways to specify the loss side-effect - via mae_loss, via a lambda function returning a constant, and via a return_value instead of a side_effect. You could consider using return_value for this one to reduce that down to 2 ways, at least.
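A short sketch of the two Mock styles mentioned here, using a plain float instead of torch.tensor so it runs without torch. The StepLoss class below is a hypothetical stand-in for the real one.

```python
from unittest.mock import Mock


class StepLoss:
    """Hypothetical stand-in; only used as the Mock spec."""

    def __call__(self, prediction, target):
        raise NotImplementedError()


# side_effect: a callable invoked on every call; its return value is returned.
via_side_effect = Mock(spec=StepLoss, side_effect=lambda *_, **__: 5.25)

# return_value: the same constant result, stated directly.
via_return_value = Mock(spec=StepLoss, return_value=5.25)
```

For a constant result the two behave the same from the caller's side, so return_value is the shorter way to say it.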
      n_samples=3,
  )
- output = coupler.train_on_batch(
+ train_stepper_config = CoupledTrainStepperConfig(
nit: Avoid the 3x copy-paste of this process by making a get_train_stepper_and_batch helper that does it and calls get_stepper_and_batch internally.
Good idea, but I'll defer this cleanup to #814 since the way in which the train stepper is built is going to change.
  @dataclasses.dataclass
  class CoupledTrainStepperConfig:
Do you have an example of the updated training config committed somewhere I could check out? It would be nice to have a baseline config for coupled training; if there were one, I could see the changes to the baseline in this PR.
Update: Ah, I see test_train.py mostly fits this purpose, good. Still, it could be nice to have a baseline in the future.
Agreed, I will work on a new PR to add the baseline.
fme/coupled/test_train.py
Outdated
  loss_contributions:
    n_steps: {loss_atmos_n_steps}
  stepper:
    loss:
Question: Why is loss: type: MSE in both atmosphere: stepper: and train_stepper: atmosphere:? I am guessing it is because we haven't updated the ACE configs yet and it's required in the config, in which case that's fine, but I thought I should ask to be sure.
That's right, although including it in the yaml here isn't strictly necessary since there is a default value on StepperConfig. I'll remove it so it's a bit clearer here.
fme/coupled/train/train.py
Outdated
      atmosphere_normalize=stepper.atmosphere.normalizer.normalize,
-     ocean_loss_scaling=stepper.ocean.effective_loss_scaling,
-     atmosphere_loss_scaling=stepper.atmosphere.effective_loss_scaling,
+     ocean_loss_scaling=stepper.effective_loss_scaling.ocean,
nit: pass loss_scaling: stepper.loss_scaling instead of two arguments containing the parts
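One way to read this nit, as a rough sketch: pass a single object that carries both parts rather than two parallel keyword arguments. CoupledLossScaling, both summarize functions, and the variable names "sst" and "t2m" are hypothetical and only illustrate the shape of the call.

```python
import dataclasses
from typing import Dict


@dataclasses.dataclass
class CoupledLossScaling:
    """Hypothetical container holding both components' scaling."""

    ocean: Dict[str, float]
    atmosphere: Dict[str, float]


def summarize_split(ocean_loss_scaling, atmosphere_loss_scaling):
    """Two-argument form: the caller unpacks the parts itself."""
    return {
        "ocean": sorted(ocean_loss_scaling),
        "atmosphere": sorted(atmosphere_loss_scaling),
    }


def summarize_grouped(loss_scaling: CoupledLossScaling):
    """Single-argument form: the parts travel together."""
    return {
        "ocean": sorted(loss_scaling.ocean),
        "atmosphere": sorted(loss_scaling.atmosphere),
    }


scaling = CoupledLossScaling(ocean={"sst": 1.0}, atmosphere={"t2m": 0.5})
```

The grouped form keeps the call site to one argument and leaves the ocean/atmosphere split to the receiving code.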
Adds train_stepper: CoupledTrainStepperConfig to the coupled training config, which configures and builds a CoupledTrainStepper implementing TrainStepperABC.
WARNING: This is a breaking change for existing coupled training configs.
Changes:
Component stepper loss: StepLossConfig and loss_contributions: LossContributionsConfig are now configured via the ocean: ComponentTrainingConfig and atmosphere: ComponentTrainingConfig attributes of CoupledTrainStepperConfig.
CoupledStepper no longer implements TrainStepperABC.
Removed public loss_obj and effective_loss_scaling properties from fme.ace.stepper.Stepper and added a new public method build_loss.
Tests added
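The nesting described above could look roughly like the following dataclasses. The class names and the loss, loss_contributions, ocean, atmosphere, n_steps, and type fields come from this PR; every default value and any omitted field is an assumption, so treat this as an illustration of the layout rather than the actual definitions.

```python
import dataclasses


@dataclasses.dataclass
class StepLossConfig:
    type: str = "MSE"  # assumed default


@dataclasses.dataclass
class LossContributionsConfig:
    n_steps: int = 1  # assumed default; the real config may have more fields


@dataclasses.dataclass
class ComponentTrainingConfig:
    loss: StepLossConfig = dataclasses.field(default_factory=StepLossConfig)
    loss_contributions: LossContributionsConfig = dataclasses.field(
        default_factory=LossContributionsConfig
    )


@dataclasses.dataclass
class CoupledTrainStepperConfig:
    ocean: ComponentTrainingConfig
    atmosphere: ComponentTrainingConfig


# Loss settings now sit under train_stepper, one block per component.
config = CoupledTrainStepperConfig(
    ocean=ComponentTrainingConfig(),
    atmosphere=ComponentTrainingConfig(
        loss_contributions=LossContributionsConfig(n_steps=4)
    ),
)
```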